# -*- coding: utf-8 -*-
# ===========================================
# @Time    : 2021/9/23 17:01 
# @Author  : shutao
# @FileName: stgcn_s.py
# @remark  : 
# 
# @Software: PyCharm
# Github 　： https://github.com/NameLacker
# ===========================================

import paddle
import os

from modules.exp import STGCNExp as MyExp


class Exp(MyExp):
    """Experiment definition named after the file it lives in.

    Extends the project's ``STGCNExp`` base with an ``exp_name`` derived from
    this module's filename and a paired-input training step.
    """

    def __init__(self):
        super(Exp, self).__init__()
        # Experiment name = this file's base name, cut at the first dot
        # (symlinks are resolved first by realpath).
        script_name = os.path.basename(os.path.realpath(__file__))
        self.exp_name = script_name.partition(".")[0]

    def run_train(self, data, loss_func):
        """Run one training step on a pair of inputs and return the loss.

        Args:
            data: A 3-item iterable ``(first_input, second_input, label)``.
                The two inputs are cast to float32 and the label to int64.
                Presumably these are paired images with a pair-level label
                -- confirm against the data loader.
            loss_func: Callable invoked as ``loss_func((out_1, out_2), label)``
                where ``out_1``/``out_2`` are ``self.model`` outputs for the
                two inputs.

        Returns:
            Whatever ``loss_func`` returns for this batch.
        """
        first, second, target = data
        first, second = (
            paddle.cast(sample, paddle.float32) for sample in (first, second)
        )
        target = paddle.cast(target, paddle.int64)

        # The same model embeds both inputs, one after the other.
        embeddings = tuple(self.model(sample) for sample in (first, second))
        return loss_func(embeddings, target)
